{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<center>\n",
    "<img src=\"../../img/ods_stickers.jpg\">\n",
    "## Открытый курс по машинному обучению\n",
    "<center>Автор материала: @vfdev"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.core.display import HTML\n",
    "from IPython.display import Latex, Math, display"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Итеративная стратификация данных с пересекающимися классами\n",
    "Iterative stratification of multi-label data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## О чем речь - TL;DR\n",
    "\n",
    "Есть данные, например, документы (`X`) и их теги (`y`). Естественно, что каждый документ может иметь сразу несколько тегов. Для обучения моделей важно уметь \"сбалансированно\" разбивать выборку на тренировочную и проверочную. \n",
    "Далее, рассмотрим один методов такого разбиения: iterative stratification и небольшую [библиотеку](https://github.com/trent-b/iterative-stratification), реализующую алгоритм на Python.\n",
    "\n",
    "**Исходные данные:**\n",
    "\n",
    "instance | tag_0 | tag_1 | tag_2 | ... | tag_9\n",
    "---|---|---|---|---|--- \n",
    "a | 0 | 0 | 1 | ... | 0\n",
    "b | 0 | 1 | 1 | ... | 1\n",
    "c | 1 | 0 | 1 | ... | 0\n",
    "... | ... | ... | ... | ...\n",
    "\n",
    "**Разбиение на тренировочную и проверочную выборки**\n",
    "\n",
    "<img src=\"../../img/iterative_stratification_of_multilabel_data_tldr.png\">\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Теория\n",
    "\n",
    "**Кросс-валидация** (кросс-проверка, скользящий контроль) используется для оценивания обобщающей способности алгоритмов и является также полезной процедурой для выявления переобучения. При этом, в зависимости от качества (сбалансированности) исходных данных, очень важен метод разбиения исходной выборки на обучающую и проверочную. В этом курсе мы видели несколько вариантов такого разбиения: \n",
    "- случайное разбиение (`sklearn.model_selection.KFold`) для сбалансированных данных\n",
    "- стратификационное разбиение (`sklearn.model_selection.StratifiedKFold`)\n",
    "  - сохраняет пропорции присутствия непересекающихся классов в обучающей и проверочной выборках\n",
    "- разбиение временных рядов (`sklearn.model_selection.TimeSeriesSplit`)\n",
    "\n",
    "Понятно, что в зависимости от типа данных следует применять то или иное разбиение. Итак, у нас есть данные, в которых целевая переменная $y$ является вектором $(y_0, y_1, ..., y_{K-1})$, где $y_i = \\{0, 1\\}$. Мы хотим разбить эти данные на две подвыборки, сохранив приблизительно одинаковое распределение классов. В этой тетрадке мы познакомимся с одним из методов стратификационного разбиение для данных с пересекающимися классами.\n",
    "\n",
    "**Что делать с тегами?** Варианты, как работать с данными с пересекающимися классами, могут быть разные. Хотя их рассмотрение и не входит в тему этого обзора, но приведем несколько примеров:\n",
    "- преобразование начальной задачи в проблему многоклассовой классификации: *теги -> непересекающиеся классы*\n",
    "  - Binary relevence: преобразует K тегов в K бинарных наборов данных, по одному набору данных на каждый тег\n",
    "  - Label Powerset: преобразует уникальные комбинации тегов в непересекающиеся классы, `(0, 1, 0, 1) -> 5`  \n",
    "- использование алгоритмов, способных работать с начальной задачей (с тегами) \n",
    "  - Решающие деревья\n",
    "  - kNN\n",
    "  - Нейронные сети\n",
    "  \n",
    "Более подробно об этом можно почитать по ссылкам:\n",
    "- [Multi-label classification](https://en.wikipedia.org/wiki/Multi-label_classification)\n",
    "- [Multi-Label Classification: An Overview](http://lpis.csd.auth.gr/publications/tsoumakas-ijdwm.pdf)\n",
    "- [Решающие правила для ансамбля из цепей\n",
    "вероятностных классификаторов при решении задач\n",
    "классификации с пересекающимися классами](http://jmlda.org/papers/doc/2016/JMLDA2016no3.pdf)\n",
    "- [Sklearn - Support multilabel](http://scikit-learn.org/stable/modules/multiclass.html)\n",
    "\n",
    "\n",
    "### Итеративная стратификация\n",
    "\n",
    "Это метод разбиения исходной выборки $D(X, y)$ с пересекающимися $K$ классами на $n$ непересекающихся наборов. \n",
    "Также как и в обычном методе многоклассовой стратификации, результат разбиение $S_1$, $S_2$ ... $S_n$, можно преобразовать в обучающую и тестовую подвыборки: для $i=1,...,n$ тестовая выборка - $S_i$, а обучающая подвыборка есть объединение $S_1$, ..., $S_{i-1}$, $S_{i+1}$, ... $S_n$.\n",
    "\n",
    "**Основные этапы алгоритма**: \n",
    "\n",
    "1) *Всем подвыборкам одинаковое количество элементов!* \n",
    "\n",
    "Задать число элементов $c_i$ для каждой подвыборки $S_i$, например, $c_i = \\frac{N}{n},$ где $N$ - это количество примеров в изначальной выборке $D$.\n",
    "\n",
    "2) *Всем подвыборкам одинаковое распределение классов!*  \n",
    "\n",
    "Посчитать число примеров $c^j_i$, содержащих класс $j$ в подвыборке $S_i$: $c^j_i = \\frac{D_j}{n},$ где $D_i$ число элементов исходной выборки, содержащих класс $j$. \n",
    " \n",
    "3) *Раздать все элементы исходной выборки, начиная с самых редких в самые нуждающиеся подвыборки!* \n",
    "\n",
    "Далее, элементы исходной выборки распределяются по подвыборкам следующим образом:\n",
    "  - Из оставшихся элементов $D$ найти такие $(x, y)$ с самым редко встречающимся классом (исключая ноль, как отсутствие класса). Назовем найденный класс $l$.\n",
    "  - Из заполненных подвыборок найти такую $S_i$, в которой присутствует наименьшее количество элементов с выбранным классом $l$: $i = \\text{argmax}_k(c^l_k)$\n",
    "  - Занести элемент в подвыборку $S_i$ и убрать из исходной выборки $D$. А также внести правки в $c^j_i$ и $c_i$.\n",
    "\n",
    "Более подробное описание алгоритма, а также разбор нюансов можно найти в статье [авторов алгоритма](http://lpis.csd.auth.gr/publications/tsoumakas-ijdwm.pdf). Далее в практической части, рассмотрим реализацию этого алгоритма в интерактивном режиме.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Практика\n",
    "\n",
    "В этой части мы рассмотрим итеративную стратификацию на конкретном примере. Мы будем использовать библиотеку [iterative-stratification](https://github.com/trent-b/iterative-stratification). Чтобы установить ее, просто выполните команду:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip3 install iterative-stratification"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Итак, возьмем для примера выборку `stackoverflow_sample_125k.tsv`, которая состоит из вопросов и тегов. Мы имеем 125000 вопросов и 10 уникальных тегов : 'android', 'c#', 'c++', 'html', 'ios', 'java', 'javascript', 'jquery', 'php', 'python'.\n",
    "\n",
    "Загрузим данные и добавим теги отдельными колонками."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pylab as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = pd.read_csv(\"../../data/stackoverflow_sample_125k.tsv\", sep=\"\\t\", header=None)\n",
    "data.columns = [\"sample\", \"tags\"]\n",
    "unique_tags = ['android', 'c#', 'c++', 'html', 'ios', 'java', 'javascript', 'jquery', 'php', 'python']\n",
    "tags = np.zeros((len(data), len(unique_tags)), dtype=np.uint8)\n",
    "\n",
    "for i, vals in enumerate(data['tags'].str.split(\" \").values):\n",
    "    for v in vals:\n",
    "        tags[i, unique_tags.index(v)] = 1\n",
    "\n",
    "data = pd.concat([data, pd.DataFrame(tags, columns=unique_tags)], axis=1)\n",
    "\n",
    "print(data.shape)\n",
    "data.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "На самом деле, данная выборка настолько хорошо сбалансирована, что даже при случайном разбиении на поднаборы распределение классов визуально одинаковое:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import KFold\n",
    "\n",
    "\n",
    "def display_kfold_split(data):\n",
    "\n",
    "    splt = KFold(n_splits=5, shuffle=True, random_state=12345)\n",
    "\n",
    "    y = np.arange(len(data))\n",
    "    X = np.zeros((len(data), 1))\n",
    "\n",
    "    for train_index, test_index in splt.split(X, y):\n",
    "        break\n",
    "\n",
    "    tags_counts = data[unique_tags].sum(axis=0)\n",
    "    tags_counts = tags_counts / tags_counts.max()\n",
    "\n",
    "    train_tags_counts = data.loc[train_index, unique_tags].sum(axis=0)\n",
    "    train_tags_counts = train_tags_counts / train_tags_counts.max()\n",
    "\n",
    "    test_tags_counts = data.loc[test_index, unique_tags].sum(axis=0)\n",
    "    test_tags_counts = test_tags_counts / test_tags_counts.max()\n",
    "\n",
    "    fig = plt.figure(figsize=(15, 5))\n",
    "    plt.suptitle(\"Распределение тегов при KFold разбиении\")\n",
    "    plt.subplot(1,3,1)\n",
    "    plt.title(\"Исходное распределение\")\n",
    "    sns.barplot(x=tags_counts, y=unique_tags, orient='h')\n",
    "    plt.subplot(1,3,2)\n",
    "    plt.title(\"Тренировочное распределение\")\n",
    "    sns.barplot(x=train_tags_counts, y=unique_tags, orient='h')\n",
    "    plt.subplot(1,3,3)\n",
    "    plt.title(\"Проверочное распределение\")\n",
    "    sns.barplot(x=test_tags_counts, y=unique_tags, orient='h')\n",
    "    \n",
    "display_kfold_split(data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Исправим эту ситуацию:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.random.seed(12345)\n",
    "indices = np.random.choice(np.arange(len(data)), size=2000)\n",
    "small_data = data.loc[indices, :].reset_index(drop=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "display_kfold_split(small_data)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Теперь, как видно, подвыборки имеют разные распределения тегов."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "В библиотеке [iterative-stratification](https://github.com/trent-b/iterative-stratification) реализованы следующие классы :\n",
    "- `MultilabelStratifiedKFold` (наследованый от `sklearn.model_selection._split._BaseKFold`)\n",
    "- `MultilabelStratifiedShuffleSplit` (наследованый от `sklearn.model_selection._split.BaseShuffleSplit`)\n",
    "- `RepeatedMultilabelStratifiedKFold` (наследованый от `sklearn.model_selection._split._RepeatedSplits`)\n",
    "\n",
    "\n",
    "Вначале запустим стратификацию и посмотрим на полученный результат."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from iterstrat.ml_stratifiers import MultilabelStratifiedKFold\n",
    "\n",
    "splt = MultilabelStratifiedKFold(n_splits=5, shuffle=True, random_state=12345)\n",
    "\n",
    "y = small_data[unique_tags]\n",
    "X = np.zeros((len(small_data), 1))\n",
    "\n",
    "for train_index, test_index in splt.split(X, y):\n",
    "    break\n",
    "\n",
    "tags_counts = small_data[unique_tags].sum(axis=0)\n",
    "tags_counts = tags_counts / tags_counts.max()\n",
    "\n",
    "train_tags_counts = small_data.loc[train_index, unique_tags].sum(axis=0)\n",
    "train_tags_counts = train_tags_counts / train_tags_counts.max()\n",
    "\n",
    "test_tags_counts = small_data.loc[test_index, unique_tags].sum(axis=0)\n",
    "test_tags_counts = test_tags_counts / test_tags_counts.max()\n",
    "\n",
    "fig = plt.figure(figsize=(15, 5))\n",
    "plt.suptitle(\"Распределение тегов при MultilabelStratifiedKFold разбиении\")\n",
    "plt.subplot(1,3,1)\n",
    "plt.title(\"Исходное распределение\")\n",
    "sns.barplot(x=tags_counts, y=unique_tags, orient='h')\n",
    "plt.subplot(1,3,2)\n",
    "plt.title(\"Тренировочное распределение\")\n",
    "sns.barplot(x=train_tags_counts, y=unique_tags, orient='h')\n",
    "plt.subplot(1,3,3)\n",
    "plt.title(\"Проверочное распределение\")\n",
    "sns.barplot(x=test_tags_counts, y=unique_tags, orient='h')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Мы видим, что алгоритм сохраняет распределение классов в подвыборках. И это хорошо!\n",
    "\n",
    "Теперь, разберем в деталях реализацию алгоритма и интерактивно посмотрим как происходит разбиение. Итак, у нас на входе `y` и количество разбиений `n_splits`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_splits = 5\n",
    "n_samples = y.shape[0]\n",
    "y.shape, n_splits"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Сначала оценим сколько элементов исходной выборки может попасть в каждый сплит (подвыборку):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "r = np.asarray([1 / n_splits] * n_splits)\n",
    "c_folds = r * n_samples\n",
    "c_folds"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Это теоретическое значение является верхней границей.\n",
    "Затем оценим количество элементов каждого класса в каждой подвыборке. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "c_folds_labels = np.outer(r, y.sum(axis=0))\n",
    "c_folds_labels"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Таким образом, мы оценили, что количество элементов, содержащих тег \"android\" в каждой тестовой подвыборке (из 5 фолдов): "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "c_folds_labels[:, 0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Теперь посмотрим в интерактивном режиме как заполняются подвыборки.\n",
    "\n",
    "\n",
    "*В следующей клетке должна появиться кнопка \"Вперед\", при нажатии которой, графики распределения тегов обновляются на каждом этапе поиска редкий тегов. Данный функционал требует наличие установленного [`ipywidgets`](https://ipywidgets.readthedocs.io/en/latest/index.html). Если, кнопка не появляется, то возможно нужно установить библиотку и включить дополнения:*\n",
    "```\n",
    "pip install ipywidgets\n",
    "jupyter nbextension enable --py widgetsnbextension\n",
    "```\n",
    "*В докере festline данный пакет присутствует и вышеописанная визуализация функционирует :)*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import display\n",
    "from ipywidgets import interactive\n",
    "\n",
    "random_state = np.random.RandomState(12345)\n",
    "labels = np.asarray(y.values, dtype=bool)\n",
    "folds = np.zeros((n_splits, labels.shape[1]), dtype=np.int)\n",
    "remaining_labels_mask = np.ones(n_samples, dtype=bool)\n",
    "r = np.asarray([1 / n_splits] * n_splits)\n",
    "c_folds = r * n_samples\n",
    "c_folds_labels = np.outer(r, labels.sum(axis=0))\n",
    "\n",
    "\n",
    "def draw():\n",
    "    plt.figure(figsize=(20, 5))\n",
    "    for i in range(n_splits):\n",
    "        plt.subplot(1,n_splits, i + 1)\n",
    "        plt.title(\"Подвыборка {}\".format(i))\n",
    "        sns.barplot(x=folds[i, :], y=unique_tags, orient='h')\n",
    "        plt.xlim((0, 100))\n",
    "        if i > 0:\n",
    "            plt.yticks([])\n",
    "\n",
    "            \n",
    "def go_forward():\n",
    "    if not np.any(remaining_labels_mask):\n",
    "        print(\"Конец процедуры\")\n",
    "        return\n",
    "    num_labels = labels[remaining_labels_mask].sum(axis=0)\n",
    "    if num_labels.sum() == 0:\n",
    "        sample_idxs = np.where(remaining_labels_mask)[0]\n",
    "        for sample_idx in sample_idxs:\n",
    "            fold_idx = np.where(c_folds == c_folds.max())[0]\n",
    "            if fold_idx.shape[0] > 1:\n",
    "                fold_idx = [fold_idx[random_state.choice(fold_idx.shape[0])]]\n",
    "            folds[fold_idx[0], :] += labels[sample_idx]\n",
    "            c_folds[fold_idx] -= 1\n",
    "        return    \n",
    "    label_idx = np.where(num_labels == num_labels[np.nonzero(num_labels)].min())[0]\n",
    "    if label_idx.shape[0] > 1:\n",
    "        label_idx = label_idx[random_state.choice(label_idx.shape[0])]\n",
    "    sample_idxs = np.where(np.logical_and(labels[:, label_idx].flatten(), remaining_labels_mask))[0]\n",
    "\n",
    "    print(\"Найден 'редкий' класс: {} и {} элементов, сожержащих этот класс\"\n",
    "          .format(unique_tags[label_idx[0]], len(sample_idxs)))\n",
    "    print(\"Элементы будут распределены по нуждающимся выборкам\")\n",
    "    for sample_idx in sample_idxs:\n",
    "        label_folds = c_folds_labels[:, label_idx]\n",
    "        fold_idx = np.where(label_folds == label_folds.max())[0]\n",
    "        if fold_idx.shape[0] > 1:\n",
    "            temp_fold_idx = np.where(c_folds[fold_idx] == c_folds[fold_idx].max())[0]\n",
    "            fold_idx = fold_idx[temp_fold_idx]\n",
    "            if temp_fold_idx.shape[0] > 1:\n",
    "                fold_idx = [fold_idx[random_state.choice(temp_fold_idx.shape[0])]]\n",
    "                \n",
    "        folds[fold_idx[0], :] += labels[sample_idx]        \n",
    "        remaining_labels_mask[sample_idx] = False        \n",
    "        c_folds_labels[fold_idx, labels[sample_idx]] -= 1\n",
    "        c_folds[fold_idx] -= 1\n",
    "\n",
    "    print(\"Остаток элементов по классам в подвыборке: \\n{}\".format(c_folds_labels))\n",
    "\n",
    "    \n",
    "def interact_draw():\n",
    "    go_forward()\n",
    "    draw()\n",
    "\n",
    "    \n",
    "options = dict(manual=True, auto_display=True)\n",
    "interact_manual = interactive(interact_draw, options)\n",
    "interact_manual.manual_button.description = \"Вперед\"\n",
    "interact_manual"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Заключение\n",
    "\n",
    "Цель данного обзора - познакомиться с одним из методов разбиения данных с пересекающимися классами. Естественно, что качество исходного набора будет влиять на качество стратификации. \n",
    "Плюсы алгоритма (от авторов алгоритма): \n",
    "- сохраняет дисбаланс классов в подвыборках\n",
    "- применимся в случае, когда уникальных комбинаций классов соизмеримо много с количеством данных\n",
    "\n",
    "### Еще о применении\n",
    "\n",
    "Данный функционал может быть также использован для разбиения датасетов с изображениями и разметкой в виде а) мультиклассовая сегментационная маска объектов или б) мультиклассовые прямоугольники.\n",
    "Для этих данных алгоритм можно усовершенствовать и добавить размер объекта как \"вес\" класса для каждого элемента данных.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Ссылки:\n",
    "\n",
    "- http://lpis.csd.auth.gr/publications/sechidis-ecmlpkdd-2011.pdf\n",
    "- https://www.slideshare.net/tsoumakas/on-the-stratification-of-multilabel-data\n",
    "- https://github.com/trent-b/iterative-stratification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
